/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <benchmark/benchmark.h>
#include "execute_graph_builder.h"
#include "common_reg.h"
#include "ge/transformer_utils.h"
#include "shape_wrapper.h"
namespace ge {
// Helper for declaring a google-benchmark case in one step.
// It expands to a forward declaration of `void F(benchmark::State &state)`,
// then registers F via BENCHMARK(F), and finally opens F's definition. So
//   BENCHMARK_TEST(Name) { ...body using `state`... }
// both registers and defines the benchmark function `Name`.
#define BENCHMARK_TEST(F)          \
void F(benchmark::State &state); \
BENCHMARK(F);                    \
void F(benchmark::State &state)


std::unique_ptr<ExecuteGraph> BuildGraph() {
  ExecuteGraphBuilder gb;
  auto registry = RegOpDefs(gb);
  auto nb = gb.GetNodeBuilder(gb.AddNode());

  nb->SetOpDef(registry.matmul)
  .SetName("HelloMatMul")
  .SetInputName(0, "x1")
  .SetInputName(1, "x2")
  .SetOutputName(0, "y");

  auto graph = gb.Build();
  auto td = graph->GetNodeByIndex(0)->GetOpDesc()->MutableInputDesc(0);
  td->SetDataType(DT_FLOAT16);
  td->SetFormat(FORMAT_NC1HWC0);
  td->SetShape(Shape({8,3,224,224}));
  td->SetOriginDataType(DT_FLOAT16);
  td->SetOriginFormat(FORMAT_NCHW);
  td->SetOriginShape(Shape({8,3,224,224}));

  td = graph->GetNodeByIndex(0)->GetOpDesc()->MutableInputDesc(1);
  td->SetDataType(DT_FLOAT16);
  td->SetFormat(FORMAT_NC1HWC0);
  td->SetShape(Shape({8,3,224,224}));
  td->SetOriginDataType(DT_FLOAT16);
  td->SetOriginFormat(FORMAT_NCHW);
  td->SetOriginShape(Shape({8,3,224,224}));

  td = graph->GetNodeByIndex(0)->GetOpDesc()->MutableOutputDesc(0);
  td->SetDataType(DT_FLOAT16);
  td->SetFormat(FORMAT_NC1HWC0);
  td->SetShape(Shape({8,3,224,224}));
  td->SetOriginDataType(DT_FLOAT16);
  td->SetOriginFormat(FORMAT_NCHW);
  td->SetOriginShape(Shape({8,3,224,224}));

  return std::move(graph);
}

// Times NodeShapeTransUtils::Init() on the MatMul op desc from BuildGraph().
// The graph and the transformer are constructed once, outside the timed
// loop. Only Init() runs per iteration.
BENCHMARK_TEST(TransformerInit) {
  auto graph = BuildGraph();
  auto op_desc = graph->GetNodeByIndex(0)->GetOpDesc();

  NodeShapeTransUtils transformer(op_desc);
  for (auto _ : state) {
    transformer.Init();
  }
}
// Times NodeShapeTransUtils::CatchFormatAndShape() on an already-initialised
// transformer. The graph build, transformer construction and Init() happen
// once, outside the timed loop.
BENCHMARK_TEST(TransformerCache) {
  auto graph = BuildGraph();
  auto op_desc = graph->GetNodeByIndex(0)->GetOpDesc();

  NodeShapeTransUtils transformer(op_desc);
  transformer.Init();
  for (auto _ : state) {
    transformer.CatchFormatAndShape();
  }
}

// Times NodeShapeTransUtils::UpdateFormatAndShape(). Setup runs once outside
// the timed loop: build the graph, construct the transformer, call Init() and
// CatchFormatAndShape().
BENCHMARK_TEST(TransformerUpdate) {
  auto graph = BuildGraph();
  auto op_desc = graph->GetNodeByIndex(0)->GetOpDesc();

  NodeShapeTransUtils transformer(op_desc);
  transformer.Init();
  transformer.CatchFormatAndShape();
  for (auto _ : state) {
    transformer.UpdateFormatAndShape();
  }
}

// Times the full Init() -> CatchFormatAndShape() -> UpdateFormatAndShape()
// sequence per iteration. Only the graph build and transformer construction
// are outside the timed loop.
BENCHMARK_TEST(TransformerCacheAndUpdate) {
  auto graph = BuildGraph();
  auto op_desc = graph->GetNodeByIndex(0)->GetOpDesc();

  NodeShapeTransUtils transformer(op_desc);
  for (auto _ : state) {
    transformer.Init();
    transformer.CatchFormatAndShape();
    transformer.UpdateFormatAndShape();
  }
}

// Simulates a 2-input / 1-output infer-shape step by index. Each iteration
// fetches both input descs, binds a reference to each input's dims (no
// copy), then writes a fixed shape to output 0.
BENCHMARK_TEST(InferShape2In1OutByIndex) {
  auto graph = BuildGraph();
  auto node = graph->GetNodeByIndex(0);
  auto op_desc = node->GetOpDesc();
  for (auto _ : state) {
    auto input0 = op_desc->MutableInputDesc(0);
    auto &input0_dims = input0->GetShape().GetDims();
    auto input1 = op_desc->MutableInputDesc(1);
    // Read input 1's own dims (previously this re-read input0 by mistake).
    auto &input1_dims = input1->GetShape().GetDims();
    auto output0 = op_desc->MutableOutputDesc(0);
    output0->SetShape(Shape({1,16,2,2}));
  }
}
// Same as InferShape2In1OutByIndex, except each input's dims are taken by
// value through GetDimsVec(). This measures the cost of copying the dims
// instead of referencing them.
BENCHMARK_TEST(InferShape2In1OutByIndex_GetVec) {
  auto graph = BuildGraph();
  auto node = graph->GetNodeByIndex(0);
  auto op_desc = node->GetOpDesc();
  for (auto _ : state) {
    auto input0 = op_desc->MutableInputDesc(0);
    auto input0_dims = input0->GetShape().GetDimsVec();

    auto input1 = op_desc->MutableInputDesc(1);
    // Copy input 1's own dims (previously this copied input0's by mistake).
    auto input1_dims = input1->GetShape().GetDimsVec();

    auto output0 = op_desc->MutableOutputDesc(0);
    output0->SetShape(Shape({1,16,2,2}));
  }
}

// Times constructing a 4-D Shape {1, 16, 2, 2}.
// DoNotOptimize keeps the optimizer from eliding the unused object.
BENCHMARK_TEST(ShapeCreate) {
  for (auto _ : state) {
    Shape shape({1,16,2,2});
    benchmark::DoNotOptimize(shape);
  }
}

// Times copy-constructing a Shape from an existing 4-D Shape.
// DoNotOptimize keeps the otherwise-unused copy from being elided.
BENCHMARK_TEST(ShapeCopy) {
  Shape shape1({1,16,2,2});
  for (auto _ : state) {
    auto shape2 = shape1;
    benchmark::DoNotOptimize(shape2);
  }
}

// Times constructing a ShapeWrapper from the dims {1, 16, 2, 2}.
// DoNotOptimize keeps the optimizer from eliding the unused object.
BENCHMARK_TEST(SharedPtrShapeCreate) {
  for (auto _ : state) {
    ShapeWrapper shared_shape({1,16,2,2});
    benchmark::DoNotOptimize(shared_shape);
  }
}

// Times copy-constructing a ShapeWrapper from an existing one.
// DoNotOptimize keeps the otherwise-unused copy from being elided.
BENCHMARK_TEST(SharedPtrShapeCopy) {
  ShapeWrapper shape1({1,16,2,2});
  for (auto _ : state) {
    auto shape2 = shape1;
    benchmark::DoNotOptimize(shape2);
  }
}

// Baseline: times constructing a std::vector<int64_t> holding 4 dims.
// DoNotOptimize stops the compiler from eliding the unused vector and its
// heap allocation.
BENCHMARK_TEST(VecCreate) {
  for (auto _ : state) {
    std::vector<int64_t> shape({1,16,2,2});
    benchmark::DoNotOptimize(shape);
  }
}

// Baseline: times constructing a fixed-size std::array<int64_t, 4>.
// Without DoNotOptimize, the unused array is trivially removable and the loop
// would measure nothing.
BENCHMARK_TEST(ArrayCreate) {
  for (auto _ : state) {
    std::array<int64_t, 4> shape({1,16,2,2});
    benchmark::DoNotOptimize(shape);
  }
}

// Reports whether `shape` has at least one negative dimension, which this
// helper treats as "unknown". A shape with no dims is reported as known
// (false).
bool IsUnknownShape(const Shape &shape) {
  bool has_negative_dim = false;
  for (const auto dim : shape.GetDims()) {
    if (dim < 0) {
      has_negative_dim = true;
      break;
    }
  }
  return has_negative_dim;
}

// Times IsUnknownShape() on a fully known 4-D shape, so every dim is scanned.
// The result goes through DoNotOptimize so the call is not discarded as
// dead code.
BENCHMARK_TEST(PerfIsUnknownShape) {
  Shape shape1({1,16,2,2});
  for (auto _ : state) {
    bool unknown = IsUnknownShape(shape1);
    benchmark::DoNotOptimize(unknown);
  }
}
}  // namespace ge
